{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interactive Inference Example: Text to Speech to Text\n",
    "\n",
    "This example shows how to set up interactive inference to demo OpenSeq2Seq models. This example will convert text to spoken English via a Text2Speech model and then back to English text via a Speech2Text model.\n",
    "\n",
    "Requirements:\n",
    "* checkpoints for both model\n",
    "* configs for both models\n",
    "\n",
    "Steps:\n",
    "1. Put the Text2Speech checkpoint and config inside a new directory\n",
    "    1. For this example, it is assumed to be inside the Infer_T2S subdirectory\n",
    "2. Put the Speech2Text checkpoint and config inside a new directory\n",
    "    1. For this example, it is assumed to be inside the Infer_S2T subdirectory\n",
    "3. Run jupyter notebook and run all cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython\n",
    "import librosa\n",
    "\n",
    "import numpy as np\n",
    "import scipy.io.wavfile as wave\n",
    "import tensorflow as tf\n",
    "\n",
    "from open_seq2seq.utils.utils import deco_print, get_base_config, check_logdir,\\\n",
    "                                     create_logdir, create_model, get_interactive_infer_results\n",
    "from open_seq2seq.models.text2speech import save_audio\n",
    "\n",
    "# Define the command line arguments that one would pass to run.py here\n",
    "args_S2T = [\"--config_file=Infer_S2T/config.py\",\n",
    "        \"--mode=interactive_infer\",\n",
    "        \"--logdir=Infer_S2T/\",\n",
    "        \"--batch_size_per_gpu=1\",\n",
    "]\n",
    "args_T2S = [\"--config_file=Infer_T2S/config.py\",\n",
    "        \"--mode=interactive_infer\",\n",
    "        \"--logdir=Infer_T2S/\",\n",
    "        \"--batch_size_per_gpu=1\",\n",
    "]\n",
    "\n",
    "# A simpler version of what run.py does. It returns the created model and its saved checkpoint\n",
    "def get_model(args, scope):\n",
    "    with tf.variable_scope(scope):\n",
    "        args, base_config, base_model, config_module = get_base_config(args)\n",
    "        checkpoint = check_logdir(args, base_config)\n",
    "        model = create_model(args, base_config, config_module, base_model, None)\n",
    "    return model, checkpoint\n",
    "\n",
    "model_S2T, checkpoint_S2T = get_model(args_S2T, \"S2T\")\n",
    "model_T2S, checkpoint_T2S = get_model(args_T2S, \"T2S\")\n",
    "\n",
    "# Create the session and load the checkpoints\n",
    "sess_config = tf.ConfigProto(allow_soft_placement=True)\n",
    "sess_config.gpu_options.allow_growth = True\n",
    "sess = tf.InteractiveSession(config=sess_config)\n",
    "vars_S2T = {}\n",
    "vars_T2S = {}\n",
    "for v in tf.get_collection(tf.GraphKeys.VARIABLES):\n",
    "    if \"S2T\" in v.name:\n",
    "        vars_S2T[\"/\".join(v.op.name.split(\"/\")[1:])] = v\n",
    "    if \"T2S\" in v.name:\n",
    "        vars_T2S[\"/\".join(v.op.name.split(\"/\")[1:])] = v\n",
    "saver_S2T = tf.train.Saver(vars_S2T)\n",
    "saver_T2S = tf.train.Saver(vars_T2S)\n",
    "saver_S2T.restore(sess, checkpoint_S2T)\n",
    "saver_T2S.restore(sess, checkpoint_T2S)\n",
    "\n",
    "# line = \"I was trained using Nvidia's Open Sequence to Sequence framework.\"\n",
    "\n",
    "# Define the inference function\n",
    "n_fft = model_T2S.get_data_layer().n_fft\n",
    "sampling_rate = model_T2S.get_data_layer().sampling_rate\n",
    "def infer(line):\n",
    "    print(\"Input English\")\n",
    "    print(line)\n",
    "    \n",
    "    # Generate speech\n",
    "    results = get_interactive_infer_results(model_T2S, sess, model_in=[line])\n",
    "    prediction = results[1][1][0]\n",
    "    audio_length = results[1][4][0]\n",
    "    prediction = prediction[:audio_length-1,:]\n",
    "    prediction = model_T2S.get_data_layer().get_magnitude_spec(prediction, is_mel=True)\n",
    "    wav = save_audio(prediction, \"unused\", \"unused\", sampling_rate=sampling_rate, save_format=\"np.array\", n_fft=n_fft)\n",
    "    audio = IPython.display.Audio(wav, rate=sampling_rate)\n",
    "    wav = librosa.core.resample(wav, sampling_rate, 16000)\n",
    "    print(\"Generated Audio\")\n",
    "    IPython.display.display(audio)\n",
    "\n",
    "    if model_T2S.get_data_layer()._both:\n",
    "        mag_prediction = results[1][5][0]\n",
    "        mag_prediction = mag_prediction[:audio_length-1,:]\n",
    "        mag_prediction = model_T2S.get_data_layer().get_magnitude_spec(mag_prediction)\n",
    "        wav = save_audio(mag_prediction, \"unused\", \"unused\", sampling_rate=sampling_rate, save_format=\"np.array\", n_fft=n_fft)\n",
    "        audio = IPython.display.Audio(wav, rate=sampling_rate)\n",
    "        wav = librosa.core.resample(wav, sampling_rate, 16000)\n",
    "        print(\"Generated Audio from magnitude spec\")\n",
    "        IPython.display.display(audio)\n",
    "\n",
    "    \n",
    "\n",
    "    # Recognize speech\n",
    "    model_in = wav\n",
    "    results = get_interactive_infer_results(model_S2T, sess, model_in=[model_in])\n",
    "    english_recognized = results[0][0]\n",
    "\n",
    "    print(\"Recognized Speech\")\n",
    "    print(english_recognized)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "while True:\n",
    "    line = input()\n",
    "    if line == \"\":\n",
    "        break\n",
    "    IPython.display.clear_output()\n",
    "    infer(line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
